import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import custom_fwd, custom_bwd
from typing import Any, Dict, List, Optional
from torch import Tensor


class ParallelLinear(torch.autograd.Function):
    """Grouped linear map: one weight matrix per contiguous group of input rows.

    The rows of ``input`` are split along dim 0 into consecutive groups whose
    lengths are given by ``expert_size``; group ``i`` is multiplied by
    ``weight[i]`` (and optionally shifted by ``bias[i]``).

    Shapes, as required by the operations below:
        input:        (N, in_features)
        expert_size:  (E,) integer tensor with ``sum(expert_size) == N``
        weight:       (E, in_features, out_features)
        bias:         (E, out_features) or ``None``
        output:       (N, out_features)

    A group may be empty (size 0); it then contributes no output rows and a
    zero weight/bias gradient.
    """

    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(ctx, input, expert_size, weight, bias=None):
        """Compute the grouped linear map and stash the inputs for backward.

        ``custom_fwd(cast_inputs=torch.float16)`` means that, inside a
        CUDA-autocast region, floating-point CUDA inputs are cast to float16
        and this function runs with autocast disabled.
        """
        output = ParallelLinear.forward_scriptable(input, expert_size, weight, bias)
        ctx.save_for_backward(input, expert_size, weight, bias)
        return output

    @staticmethod
    def _group_sizes(expert_size: Tensor, num_linears: int) -> List[int]:
        """Return ``expert_size`` as a Python list, checking it has one entry per expert."""
        sizes: List[int] = expert_size.tolist()
        if len(sizes) != num_linears:
            # Without this check, extra entries would leave rows of the
            # torch.empty() output buffers unwritten (returned as garbage).
            raise ValueError(
                'expert_size has {} entries but weight has {} experts'.format(
                    len(sizes), num_linears))
        return sizes

    @staticmethod
    def forward_scriptable(input: Tensor, expert_size: Tensor, weight: Tensor, bias: Optional[Tensor]):
        """Forward computation without any autograd bookkeeping.

        Args:
            input: (N, in_features) rows, grouped contiguously by expert.
            expert_size: (E,) number of rows belonging to each expert.
            weight: (E, in_features, out_features) per-expert matrices.
            bias: optional (E, out_features) per-expert bias.

        Returns:
            Tensor of shape (N, out_features) with the same device and dtype
            as ``input``.

        Raises:
            ValueError: if ``expert_size`` does not have exactly E entries.
        """
        num_linears = weight.size(0)  # E
        expert_size_list = ParallelLinear._group_sizes(expert_size, num_linears)

        # One preallocated output; each expert writes directly into its own
        # row slice, so no torch.cat is needed afterwards.
        output_buf: Tensor = torch.empty(
            (input.size(0), weight.size(2)), device=input.device, dtype=input.dtype)

        # Both tuples have length E; entries for unused experts have 0 rows.
        input_list = input.split(expert_size_list, dim=0)
        output_buf_list = output_buf.split(expert_size_list, dim=0)

        for i in range(num_linears):
            torch.mm(input_list[i], weight[i], out=output_buf_list[i])

        if bias is not None:
            for i in range(num_linears):
                # In-place add on the view updates output_buf itself.
                output_buf_list[i].add_(bias[i])

        return output_buf

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_out):
        """Return gradients for ``(input, expert_size, weight, bias)``.

        ``custom_bwd`` makes backward run with the autocast state recorded by
        ``custom_fwd`` during forward.
        """
        input, expert_size, weight, bias = ctx.saved_tensors
        return ParallelLinear.backward_scriptable(grad_out, input, expert_size, weight, bias)

    @staticmethod
    def backward_scriptable(grad_out: Tensor,
                 input: Tensor, expert_size: Tensor,
                 weight: Tensor, bias: Optional[Tensor]):
        """Backward computation for the grouped linear map.

        Args:
            grad_out: (N, out_features) gradient w.r.t. the forward output.
            input, expert_size, weight, bias: the tensors passed to forward.

        Returns:
            ``(d_input, None, d_weight, d_bias)``; the ``None`` is for the
            non-differentiable ``expert_size`` and ``d_bias`` is ``None`` when
            no bias was used.

        Raises:
            ValueError: if ``expert_size`` does not have exactly E entries.
        """
        num_linears = weight.size(0)
        expert_size_list = ParallelLinear._group_sizes(expert_size, num_linears)

        # Transposed per-expert inputs, each (in_features, n_i), for d_weight.
        input_t_list = input.t().split(expert_size_list, dim=1)
        grad_list = grad_out.split(expert_size_list, dim=0)

        d_input = torch.empty_like(input)
        d_input_list = d_input.split(expert_size_list, dim=0)
        d_weight = torch.empty_like(weight)

        # (E, out_features, in_features) view used for d_input = grad @ W^T.
        weight_t = weight.permute(0, 2, 1)

        for i in range(num_linears):
            torch.mm(grad_list[i], weight_t[i], out=d_input_list[i])
            torch.mm(input_t_list[i], grad_list[i], out=d_weight[i])

        if bias is not None:
            d_bias = torch.empty_like(bias)
            for i in range(num_linears):
                torch.sum(grad_list[i], dim=0, keepdim=False, out=d_bias[i])
        else:
            d_bias = None

        return d_input, None, d_weight, d_bias


class ParallelExperts(nn.Module):
    """A bank of ``num_experts`` independent linear maps applied via ``ParallelLinear``.

    Attributes:
        w: parameter of shape (num_experts, input_size, output_size).
        b: parameter of shape (num_experts, output_size), or ``None`` when the
            module was built with ``bias=False``.
    """

    def __init__(self, num_experts, input_size, output_size, bias=False) -> None:
        """Allocate the per-expert weights (and optional biases) and initialize them."""
        super().__init__()
        self.w = nn.Parameter(torch.empty(num_experts, input_size, output_size))
        if bias:
            self.b = nn.Parameter(torch.zeros(num_experts, output_size))
        else:
            self.b = None
        self.reset_parameters()

    def extra_repr(self):
        """Return the size summary shown in the module's ``repr``."""
        return 'num_experts={}, input_size={}, output_size={}'.format(
            self.w.size(0), self.w.size(1), self.w.size(2))

    def reset_parameters(self) -> None:
        """Re-initialize weights and biases in place.

        Weights are drawn from U(-1/input_size, 1/input_size); biases, when
        present, from U(-1/sqrt(input_size), 1/sqrt(input_size)).
        """
        # Each expert's matrix is stored as (input_size, output_size), so the
        # fan-in is dim 1 of ``w``. nn.init._calculate_fan_in_and_fan_out
        # assumes an (out, in) layout and would report output_size here.
        fan_in = self.w.size(1)
        w_bound = 1. / fan_in if fan_in > 0 else 0.
        nn.init.uniform_(self.w, -w_bound, w_bound)
        if self.b is not None:
            b_bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.b, -b_bound, b_bound)

    def forward(self, inputs, expert_size):
        """Apply each expert to its contiguous slice of ``inputs``.

        Args:
            inputs: rows grouped contiguously by expert.
            expert_size: integer tensor giving the number of rows per expert.

        Returns:
            The result of ``ParallelLinear.apply`` on the inputs and this
            module's parameters.
        """
        results = ParallelLinear.apply(inputs, expert_size, self.w, self.b)
        return results

# import math

# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# from torch.cuda.amp import custom_fwd, custom_bwd


# class ParallelLinear(torch.autograd.Function):
#     @staticmethod
#     @custom_fwd
#     def forward(ctx, input, expert_size, weight, bias=None):
#         output_list = []
#         expert_size_list = expert_size.tolist()
#         input_list = input.split(expert_size_list, dim=0)
#         for i in range(weight.size(0)):
#             if bias is not None:
#                 o_i = torch.mm(input_list[i], weight[i]) + bias[i]
#             else:
#                 o_i = torch.mm(input_list[i], weight[i])
#             output_list.append(o_i)
#         output = torch.cat(output_list, dim=0)
#         variables = (input, expert_size, weight, bias)
#         ctx.save_for_backward(*variables)
#         return output

#     @staticmethod
#     @custom_bwd
#     def backward(ctx, grad_out):
#         input, expert_size, weight, bias = ctx.saved_tensors
#         num_linears = weight.size(0)

#         expert_size_list = expert_size.tolist()
#         input_list = input.split(expert_size_list, dim=0)
#         grad_list = grad_out.split(expert_size_list, dim=0)

#         d_input_list = []
#         for i in range(num_linears):
#             d_input_list.append(torch.einsum('bi,ji->bj', grad_list[i], weight[i]))
#         d_input = torch.cat(d_input_list, dim=0)

#         d_weight_list = []
#         for i in range(num_linears):
#             d_weight_list.append(torch.einsum('bi,bj->ij', input_list[i], grad_list[i]))
#         d_weight = torch.stack(d_weight_list, dim=0)

#         if bias is not None:
#             d_bias_list = []
#             for i in range(num_linears):
#                 d_bias_list.append(grad_list[i].sum(0))
#             d_bias = torch.stack(d_bias_list, dim=0)
#         else:
#             d_bias = None
#         return d_input, None, d_weight, d_bias


# class ParallelExperts(nn.Module):
#     def __init__(self, num_experts, input_size, output_size, bias=False) -> None:
#         super().__init__()
#         self.w = nn.Parameter(torch.empty(num_experts, input_size, output_size))
#         if bias:
#             self.b = nn.Parameter(torch.zeros(num_experts, output_size))
#         else:
#             self.b = None

#         self.reset_parameters()

#     def reset_parameters(self) -> None:
#         std = math.sqrt(2.0 / float(self.w.size(1) + self.w.size(2)))
#         a = math.sqrt(3.0) * std
#         nn.init.uniform_(self.w, -a, a)

#     def forward(self, inputs, expert_size):
#         results = ParallelLinear.apply(inputs, expert_size, self.w, self.b)
#         return results
